

import torch
import torch.utils.data 
from torch.nn import functional as F
import pytorch_lightning as pl
from pytorch_lightning.callbacks import ModelCheckpoint, Callback
from pytorch_lightning import loggers as pl_loggers

import os
import json
import time
from tqdm.auto import tqdm
from einops import rearrange, reduce
import numpy as np
import trimesh


from models import * 
from utils import mesh, evaluate, reconstruct
from diff_utils.helpers import * 

def vis_recon(test_dataloader, sdf_model, vae_model, recon_dir, take_mod=False, calc_cd=False):
    """Reconstruct, sample, or evaluate meshes; the mode is chosen by the global ``args``.

    Branches, checked in this order:

    * ``args.evaluate`` -- for every path from ``test_dataloader.get_all_files()``, load
      ``<recon_dir>/<cls>/<mesh>/reconstruct.ply``, sample as many surface points as the
      reference cloud has, and print the dict returned by
      ``evaluation_metrics.compute_all_metrics(recon, ref, accelerated_cd=False)``.
    * ``args.take_mod`` and not ``args.sample`` -- load latents from ``.txt`` files (every
      ``.txt`` file in ``args.mod_folder`` if set, otherwise the entries of ``args.take_mod``),
      decode each with ``vae_model.decode`` and write a mesh named
      ``<recon_dir>/modulation_recon/<output_name><idx>``.
    * ``args.sample`` -- draw one sample with ``vae_model.sample`` and write it to
      ``<recon_dir>/modulation_recon/<output_name>``.
    * otherwise -- for each ``(data, filename)`` batch from ``test_dataloader``, compute plane
      features with ``sdf_model.pointnet``, run ``vae_model.generate`` on them and write the
      mesh to ``<recon_dir>/<cls>/<mesh>/reconstruct``. When ``calc_cd`` is true, also call
      ``evaluate.main`` with ``<recon_dir>/cd.csv``; its exceptions are printed and skipped.

    ``<output_name>`` is ``args.output_name`` when truthy, else ``"mod_recon"``.
    ``<cls>`` and ``<mesh>`` are the third- and second-to-last ``/``-separated components of
    the source file path.

    Args:
        test_dataloader: In ``args.evaluate`` mode, an object exposing ``get_all_files()``
            that returns ``(list_of_point_cloud_tensors, list_of_paths)``. In the default mode,
            an iterable of ``(data, filename)`` where ``filename[0]`` is a path string.
        sdf_model: Model passed to ``mesh.create_mesh``; must expose
            ``pointnet.get_plane_features``.
        vae_model: Model exposing ``decode``/``sample``/``generate`` as needed by the branch.
        recon_dir: Root output directory.
        take_mod: Accepted for caller compatibility but NOT read here; the branch is chosen
            by the global ``args.take_mod``.
        calc_cd: In the default branch, also run ``evaluate.main`` per mesh.
    """
    resolution = 64
    recon_batch = 2**20

    with torch.no_grad():
        if args.evaluate:
            point_clouds, pc_paths = test_dataloader.get_all_files()
            point_clouds = torch.stack(point_clouds)

            # One resampled reconstruction per reference cloud, same point count.
            recon_meshes = torch.empty(*point_clouds.shape)

            for idx, path in enumerate(pc_paths):
                cls_name = path.split("/")[-3]
                mesh_name = path.split("/")[-2]
                mesh_filename = os.path.join(recon_dir, "{}/{}/reconstruct".format(cls_name, mesh_name))
                # NOTE(review): assumes mesh.create_mesh wrote "<mesh_filename>.ply" — confirm.
                recon_mesh = trimesh.load(os.path.join(os.getcwd(), mesh_filename) + ".ply")
                recon_pc, _ = trimesh.sample.sample_surface(recon_mesh, point_clouds.shape[1])
                recon_meshes[idx] = torch.from_numpy(recon_pc)

            # Argument order matches the label: reference first, reconstruction second.
            print("ref, recon shapes: ", point_clouds.shape, recon_meshes.shape)
            results = evaluation_metrics.compute_all_metrics(recon_meshes.float(), point_clouds.float(), accelerated_cd=False)
            for k, v in results.items():
                print(k, ": ", v)

        elif args.take_mod and not args.sample:
            if args.mod_folder:
                # Sorted so the index -> output-name mapping is reproducible
                # (os.listdir returns entries in arbitrary order).
                latent_files = [
                    os.path.join(args.mod_folder, f)
                    for f in sorted(os.listdir(args.mod_folder))
                    if os.path.isfile(os.path.join(args.mod_folder, f)) and f.endswith(".txt")
                ]
            else:
                # NOTE(review): iterated as a collection of paths; if args.take_mod is a
                # single string this iterates its characters — confirm the argparse setup.
                latent_files = args.take_mod

            base_name = args.output_name if args.output_name else "mod_recon"
            out_dir = os.path.join(recon_dir, "modulation_recon")
            os.makedirs(out_dir, exist_ok=True)

            for idx, latent_path in enumerate(latent_files):
                latent = torch.from_numpy(np.loadtxt(latent_path)).float().cuda()
                recon = vae_model.decode(latent)
                mesh_filename = os.path.join(out_dir, "{}{}".format(base_name, idx))
                mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)

        elif args.sample:
            recon = vae_model.sample(num_samples=1)
            name = args.output_name if args.output_name else "mod_recon"
            os.makedirs(os.path.join(recon_dir, "modulation_recon"), exist_ok=True)
            mesh_filename = os.path.join(recon_dir, "modulation_recon", name)
            mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)

        else:
            for data, filename in test_dataloader:
                filename = filename[0]

                cls_name = filename.split("/")[-3]
                mesh_name = filename.split("/")[-2]
                outdir = os.path.join(recon_dir, "{}/{}".format(cls_name, mesh_name))
                os.makedirs(outdir, exist_ok=True)
                mesh_filename = os.path.join(outdir, "reconstruct")

                # get_plane_features returns a sequence of tensors; they are concatenated on dim 1.
                plane_features = sdf_model.pointnet.get_plane_features(data.cuda())
                plane_features = torch.cat(plane_features, dim=1)
                recon = vae_model.generate(plane_features)

                mesh.create_mesh(sdf_model, recon, mesh_filename, resolution, recon_batch, from_plane_features=True)

                if calc_cd:
                    evaluate_filename = os.path.join(recon_dir, "cd.csv")
                    mesh_log_name = cls_name + "/" + mesh_name
                    # Best-effort: an evaluation failure must not stop reconstruction of the rest.
                    try:
                        evaluate.main(data, mesh_filename, evaluate_filename, mesh_log_name)
                    except Exception as e:
                        print(e)

def filter_threshold(mesh, gt_pc, threshold):
    """Report whether a saved mesh scores within ``threshold`` against a reference cloud.

    The score is whatever ``evaluate.main`` returns when called with
    ``return_value=True`` and ``prioritize_cov=True``; it is printed before comparing.

    Args:
        mesh: Mesh identifier forwarded to ``evaluate.main`` (``extract_latents`` passes a
            path without extension). Note: this name shadows the ``mesh`` module locally.
        gt_pc: Reference point cloud forwarded to ``evaluate.main``.
        threshold: Inclusive upper bound on the score.

    Returns:
        The result of ``score <= threshold``.
    """
    score = evaluate.main(gt_pc, mesh, None, None, return_value=True, prioritize_cov=True)
    print('cd:', score)
    within = score <= threshold
    return within



def extract_latents(test_dataloader, sdf_model, vae_model, save_dir, cd_threshold=0.0022, mesh_dir=None):
    """Save VAE latents for test shapes whose saved reconstruction passes a score filter.

    For each ``(data, filename)`` from ``test_dataloader``:

    1. The saved reconstruction ``<mesh_dir>/<cls>/<mesh>/reconstruct`` is scored against
       ``data`` with ``filter_threshold``. Shapes whose score exceeds ``cd_threshold`` are
       skipped.
    2. Otherwise the point cloud goes through ``sdf_model.pointnet.get_plane_features``.
       The returned tensors are concatenated on dim 1 and fed to ``vae_model.get_latent``.
       The result is written with ``np.savetxt`` under
       ``<save_dir>/modulations/<cls>/<mesh>/``:

       * if the global ``specs["random_flip"]`` is truthy, four files ``latent_0.txt`` ..
         ``latent_3.txt``: the unchanged cloud, then the cloud with its x, y or z coordinate
         negated;
       * otherwise a single ``latent.txt``.

    Any exception while processing one shape is printed (prefixed by ``<cls>/<mesh>``) and
    that shape is skipped.

    Args:
        test_dataloader: Iterable of ``(data, filename)``. ``<cls>``/``<mesh>`` are the third-
            and second-to-last ``/``-separated components of ``filename[0]``.
        sdf_model: Model exposing ``pointnet.get_plane_features``.
        vae_model: Model exposing ``get_latent``.
        save_dir: Root directory; latents are written under ``<save_dir>/modulations``.
        cd_threshold: Inclusive upper bound on the ``filter_threshold`` score. Defaults to the
            previously hard-coded value.
        mesh_dir: Root of the saved reconstructions. When ``None`` (default, original
            behavior) the module-level ``recon_dir`` global is used.
    """
    latent_dir = os.path.join(save_dir, "modulations")
    os.makedirs(latent_dir, exist_ok=True)

    # The global is only read when no explicit directory is supplied.
    mesh_root = recon_dir if mesh_dir is None else mesh_dir
    random_flip = specs.get("random_flip", False)

    def encode(points):
        # get_plane_features returns a sequence of tensors; concatenate them on dim 1.
        features = sdf_model.pointnet.get_plane_features(points.cuda())
        features = torch.cat(features, dim=1)
        return vae_model.get_latent(features)

    with torch.no_grad():
        for data, filename in test_dataloader:
            filename = filename[0]
            cls_name = filename.split("/")[-3]
            mesh_name = filename.split("/")[-2]

            saved_mesh = os.path.join(mesh_root, "{}/{}/reconstruct".format(cls_name, mesh_name))
            # Best-effort per shape: e.g. a missing reconstruction must not abort the whole run.
            try:
                if not filter_threshold(saved_mesh, data, cd_threshold):
                    continue

                outdir = os.path.join(latent_dir, "{}/{}".format(cls_name, mesh_name))
                os.makedirs(outdir, exist_ok=True)

                if random_flip:
                    # Row 0 leaves the cloud unchanged; rows 1-3 negate x, y, z respectively.
                    flip_axes = torch.tensor(
                        [[1, 1, 1], [-1, 1, 1], [1, -1, 1], [1, 1, -1]],
                        dtype=data.dtype,
                        device=data.device,
                    )
                    for flip_idx, axis in enumerate(flip_axes):
                        # Broadcasts over the last dim; assumes data is (batch, points, 3) —
                        # the same layout the previous repeat(B, N, 1) construction required.
                        flipped_data = data * axis
                        latent = encode(flipped_data)
                        np.savetxt(os.path.join(outdir, "latent_{}.txt".format(flip_idx)), latent.cpu().numpy())
                else:
                    latent = encode(data)
                    np.savetxt(os.path.join(outdir, "latent.txt"), latent.cpu().numpy())

            except Exception as e:
                print("{}/{}: {}".format(cls_name, mesh_name, e))

